from mindspore.nn.layer.activation import ReLU
import numpy as np
import time
import mindspore as ms
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import dataset as ds
from mindspore.train import Model
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, SummaryCollector

from cybertroncode.cybertron import Cybertron
from cybertroncode.models import MolCT,SchNet
from cybertroncode.blocks import MLP
from cybertroncode.readouts import GraphReadout

from alphachem import AlphaChemWithMemorySet

class ProbNet(nn.Cell):
    """Small MLP whose output is scaled by a single trainable coefficient.

    Computes ``coef * MLP(x)``, where ``coef`` is a learnable scalar
    Parameter initialized from ``coef_init``.

    Args:
        n_in (int): input feature dimension of the MLP. Default: 2.
        n_out (int): output dimension of the MLP. Default: 1.
        hidden_layers (list[int], optional): sizes of the MLP hidden layers.
            Defaults to [16, 16] when None.
        activation (str): activation function name passed to MLP.
            Default: 'swish'.
        coef_init (Tensor, optional): initial value of the trainable scale
            coefficient. Defaults to Tensor(10, ms.float32) when None.
    """

    def __init__(self,
        n_in=2,
        n_out=1,
        hidden_layers=None,
        activation='swish',
        coef_init=None
    ):
        super().__init__()
        # Use None sentinels instead of mutable/shared defaults:
        # a list default would be shared across all instances, and a Tensor
        # default would be built once at import time and reused everywhere.
        if hidden_layers is None:
            hidden_layers = [16, 16]
        if coef_init is None:
            coef_init = Tensor(10, ms.float32)

        # Trainable scalar multiplier applied to the MLP output.
        self.coef = ms.Parameter(coef_init, 'coef')

        self.mlp = MLP(n_in, n_out, hidden_layers, activation=activation)

    def construct(self, x):
        """Return the MLP output scaled by the trainable coefficient."""
        return self.coef * self.mlp(x)


if __name__ == '__main__':

    # Run in graph mode on GPU; switch to the commented line for eager
    # (PyNative) debugging.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    # Positive/negative SN2 sample sets; per the filenames, coordinates are
    # in Angstrom and energies in kcal/mol -- TODO confirm against data docs.
    pos_file = 'sn2-kcal_per_mol-A-pos.npz'
    neg_file = 'sn2-kcal_per_mol-A-neg.npz'

    pos_data = np.load(pos_file)
    neg_data = np.load(neg_file)

    # Atomic numbers of the (fixed) molecular system.
    atom_types = pos_data['Z']

    # Molecular CT representation network (alternative SchNet config below).
    mod = MolCT(
        min_rbf_dis=0.1,
        max_rbf_dis=10,
        num_rbf=64,
        rbf_sigma=0.2,
        n_interactions=3,
        dim_feature=128,
        n_heads=8,
        max_cycles=1,
        fixed_cycles=True,
        self_dis=0.1,
        unit_length='A',
        )

    # mod = SchNet(
    #     min_rbf_dis=0.1,
    #     max_rbf_dis=10,
    #     num_rbf=64,
    #     rbf_sigma=0.2,
    #     n_interactions=3,
    #     dim_feature=128,
    #     dim_filter=128,
    #     unit_length='A',
    #     activation='swish'
    #     )

    # Graph-level readout maps per-atom features to a single scalar output.
    readout = GraphReadout(n_in=mod.dim_feature,n_interactions=mod.n_interactions,activation=mod.activation,n_out=1,unit_energy=None)
    net = Cybertron(mod,atom_types=atom_types,full_connect=True,readout=readout,unit_dis='A',unit_energy=None)

    network_name = mod.network_name

    # Report parameter count and model configuration.
    tot_params = 0
    for i,param in enumerate(net.get_parameters()):
        tot_params += param.size
        print(i,param.name,param.shape)
    print('Total parameters: ',tot_params)
    net.print_info()

    n_epoch = 1
    repeat_time = 1
    batch_size = 32
    num_contrast = 40

    n_pos = pos_data['R'].shape[0]
    n_neg = neg_data['R'].shape[0]

    # Repeat the smaller dataset so both sides of the zip cover roughly the
    # same number of samples (ds.zip stops at the shorter pipeline).
    nmax = max(n_pos,n_neg)
    pos_repeat = int(np.floor(nmax/n_pos))
    # print('ds_pos.repeat: ',pos_repeat)
    neg_repeat = int(np.floor(nmax/n_neg))
    # print('ds_neg.repeat: ',neg_repeat)

    # Positive samples: coordinates R, velocities V, mask/weight M
    # (presumed meaning of V and M -- verify against AlphaChem).
    # 6310
    ds_pos = ds.NumpySlicesDataset({'Rp':pos_data['R'],'Vp':pos_data['V'],'Mp':pos_data['M']},shuffle=True)
    ds_pos = ds_pos.repeat(pos_repeat)

    # Negative samples, same layout.
    # 788254
    ds_neg = ds.NumpySlicesDataset({'Rn':neg_data['R'],'Vn':neg_data['V'],'Mn':neg_data['M']},shuffle=True)
    ds_neg = ds_neg.repeat(neg_repeat)

    # Pair positive and negative samples in each batch for contrastive training.
    dataset = ds.zip((ds_pos,ds_neg))
    dataset = dataset.batch(batch_size,drop_remainder=True)

    # Auxiliary probability network used by the contrastive loss.
    prob_net = ProbNet(2,1,[16,16],activation='swish')

    # Initial memory-set size: one entry per contrast slot per batch element.
    n_init = batch_size * num_contrast

    # Randomly draw initial positive samples for the memory set.
    idx_pos = np.arange(n_pos)
    np.random.shuffle(idx_pos)
    R_pos_init = Tensor(pos_data['R'][idx_pos[0:n_init]],ms.float32)
    V_pos_init = Tensor(pos_data['V'][idx_pos[0:n_init]],ms.float32)

    # Randomly draw initial negative samples for the memory set.
    idx_neg = np.arange(n_neg)
    np.random.shuffle(idx_neg)
    R_neg_init = Tensor(neg_data['R'][idx_neg[0:n_init]],ms.float32)
    V_neg_init = Tensor(neg_data['V'][idx_neg[0:n_init]],ms.float32)

    # Wrap model + probability net into the AlphaChem loss cell.
    loss_network = AlphaChemWithMemorySet(
        net,
        prob_net,
        batch_size=batch_size,
        num_contrast=num_contrast,
        R_pos_init=R_pos_init,
        V_pos_init=V_pos_init,
        R_neg_init=R_neg_init,
        V_neg_init=V_neg_init,
    )

    lr = 1e-4
    optim = nn.Adam(params=loss_network.trainable_params(),learning_rate=lr,weight_decay=1e-4)

    # amp_level='O0': full FP32, no automatic mixed precision.
    model = Model(loss_network,optimizer=optim,amp_level='O0')

    outdir = 'sn2_' + network_name

    # Checkpointing: save every 32 steps, keep at most 32 checkpoints.
    params_name = 'sn2-alphachem-' + network_name
    config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=32)
    ckpoint_cb = ModelCheckpoint(prefix=params_name, directory=outdir, config=config_ck)

    summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_freq=8)

    print("Start training ...")
    beg_time = time.time()
    model.train(n_epoch,dataset,callbacks=[LossMonitor(8),summary_collector,ckpoint_cb],dataset_sink_mode=False)
    end_time = time.time()

    # Report elapsed wall-clock time as HH:MM:SS.
    used_time = end_time - beg_time
    m, s = divmod(used_time, 60)
    h, m = divmod(m, 60)
    # Fixed typo in the completion message ("Fininshed" -> "Finished").
    print("Training Finished!")
    print("Training Time: %02d:%02d:%02d" % (h, m, s))